import pandas as pd
import numpy as np
import torch
from ledidi import Ledidi
from regression import SeqDataset
from enformer_pytorch.data import str_to_one_hot

# Maps a one-hot channel index (0-3) to its nucleotide letter, in A, C, G, T order.
INDEX_TO_BASE_HASH = dict(enumerate("ACGT"))


# Bases substituted during in-silico mutagenesis, in output order.
_ISM_BASES = ("A", "C", "G", "T")


def ISM_at_pos(seq, pos, drop_ref=True):
    """Return every single-base substitution of ``seq`` at index ``pos``.

    Args:
        seq: Sequence string to mutate.
        pos: Index of the base to substitute. Negative indices follow normal
            Python indexing; an out-of-range index raises ``IndexError``.
        drop_ref: If True (default), skip the base already present at ``pos``
            so only true mutants are returned. If False, all four bases are
            substituted, so the unmodified sequence is included when the
            reference base is one of A/C/G/T.

    Returns:
        List of mutated sequence strings, ordered A, C, G, T. The comparison
        against the reference base is case-sensitive, so a reference that is
        not one of the upper-case bases (e.g. ``"N"``) yields four mutants.
    """
    ref_base = seq[pos]
    # Filter rather than list.remove(): remove() raises ValueError when the
    # reference base is not one of A/C/G/T.
    bases = [b for b in _ISM_BASES if not (drop_ref and b == ref_base)]

    chars = list(seq)
    mutants = []
    for base in bases:
        chars[pos] = base
        mutants.append("".join(chars))
    return mutants


def ISM(seq, positions=None, drop_ref=True):
    """Return all single-base substitutions of ``seq`` at the given positions.

    Args:
        seq: Sequence string to mutate.
        positions: Iterable of indices to mutate. ``None`` (default) mutates
            every position in ``seq``.
        drop_ref: Forwarded to ``ISM_at_pos``; if True the reference base at
            each position is not re-emitted.

    Returns:
        Flat list of mutated sequences, grouped by position in the order the
        positions were supplied.
    """
    if positions is None:
        positions = range(len(seq))
    mutants = []
    for pos in positions:
        mutants.extend(ISM_at_pos(seq, pos, drop_ref=drop_ref))
    return mutants


class MaximizingLoss(torch.nn.Module):
    """Loss that is lowest when the supplied values are highest.

    Minimising this loss therefore maximises ``inputs``.
    """

    def forward(self, inputs, targets=None):
        """Negate ``inputs``, averaging over dim 1 first when they are 2-D.

        Args:
            inputs: Tensor of values to maximise.
            targets: Ignored. Accepted (and optional) so the module can be
                called either with one argument or with the conventional
                ``(inputs, targets)`` pair.

        Returns:
            ``-inputs.mean(dim=1)`` for a 2-D tensor, otherwise ``-inputs``.
        """
        if inputs.dim() == 2:
            return -inputs.mean(dim=1)
        return -inputs


class MinimizingLoss(torch.nn.Module):
    """Loss that is lowest when the supplied values are lowest.

    Minimising this loss minimises ``inputs`` directly.
    """

    def forward(self, inputs, targets=None):
        """Return ``inputs``, averaged over dim 1 first when they are 2-D.

        Args:
            inputs: Tensor of values to minimise.
            targets: Ignored. Accepted (and optional) so the module can be
                called either with one argument or with the conventional
                ``(inputs, targets)`` pair.

        Returns:
            ``inputs.mean(dim=1)`` for a 2-D tensor, otherwise ``inputs``
            unchanged.
        """
        if inputs.dim() == 2:
            return inputs.mean(dim=1)
        return inputs


class DesignLoss(torch.nn.Module):
    """Combined loss that pushes some output columns up and others down.

    Columns listed in ``to_max`` are scored with ``MaximizingLoss`` and columns
    listed in ``to_min`` with ``MinimizingLoss``; the two terms are summed.

    Args:
        to_max: Column index/indices of ``inputs`` to maximise. ``None``
            (default) means ``[0]``.
        to_min: Column index/indices of ``inputs`` to minimise. ``None`` or an
            empty collection disables the minimising term.
        reduction: ``None`` (default) returns the loss as produced by the
            component losses; ``"mean"`` or ``"sum"`` collapses it to a scalar.

    Raises:
        ValueError: If ``reduction`` is not ``None``, ``"mean"`` or ``"sum"``.
    """

    _REDUCTIONS = (None, "mean", "sum")

    def __init__(
        self,
        to_max=None,
        to_min=None,
        reduction=None,
    ):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        # A fresh list per instance instead of a shared mutable default.
        self.to_max = [0] if to_max is None else to_max
        self.to_min = to_min
        self.reduction = reduction

    @staticmethod
    def _has_indices(idx):
        """Return True when ``idx`` selects at least one column."""
        if idx is None:
            return False
        try:
            return len(idx) > 0
        except TypeError:
            # A scalar index (e.g. a plain int) has no len() but is valid.
            return True

    def forward(self, inputs, targets=None):
        """Compute the design loss for ``inputs``.

        Args:
            inputs: Tensor indexed as ``inputs[:, columns]``, so it must have
                at least two dimensions.
            targets: Ignored; accepted so the module can also be called with
                an ``(inputs, targets)`` pair.

        Returns:
            The summed max/min loss, reduced according to ``self.reduction``.
        """
        inputs = inputs.cpu()
        # ``None`` is passed explicitly as the (unused) targets argument so
        # the component losses work whether or not they default it.
        loss = MaximizingLoss()(inputs[:, self.to_max], None)
        # Skip an empty selection: the mean over zero columns would be NaN.
        if self._has_indices(self.to_min):
            loss = loss + MinimizingLoss()(inputs[:, self.to_min], None)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


# Directed Evolution
def evolve(
    start_seqs,
    model,
    to_max=None,
    to_min=None,
    max_iter=10,
    device=0,
    num_workers=1,
    positions=None,
):
    """Greedy directed evolution by repeated in-silico mutagenesis.

    Each iteration scores the current candidate sequences with ``model``,
    records them, picks the candidate with the lowest ``DesignLoss`` and
    replaces the candidate set with that sequence's single-base mutants.

    Args:
        start_seqs: Sequences scored in the first iteration; passed to
            ``SeqDataset``.
        model: Object exposing ``predict_on_dataset(dataset=, device=,
            num_workers=)``. Its output is passed straight to ``DesignLoss``,
            which calls ``.cpu()`` on it and indexes it as ``[:, columns]``.
        to_max: Output column indices to maximise. ``None`` means ``[0]``.
        to_min: Output column indices to minimise. ``None`` or empty disables
            the minimising term.
        max_iter: Number of iterations to run.
        device: Forwarded to ``model.predict_on_dataset``.
        num_workers: Forwarded to ``model.predict_on_dataset``.
        positions: Forwarded to ``ISM``; ``None`` is intended to mean "mutate
            every position".

    Returns:
        DataFrame with columns ``iter`` (iteration number stored as a string)
        and ``seq`` holding every candidate scored, with a fresh 0..n-1 index.
        With ``max_iter <= 0`` the frame is empty.
    """
    # Fresh defaults per call instead of shared mutable default arguments.
    if to_max is None:
        to_max = [0]
    # Normalise "nothing to minimise" to None so the loss skips that term.
    if to_min is not None and len(to_min) == 0:
        to_min = None

    # Create loss function
    loss_func = DesignLoss(to_max, to_min)
    per_iter = []

    # Iterate
    for i in range(1, max_iter + 1):

        # Create dataset
        dataset = SeqDataset(start_seqs)

        # Predict on sequences
        preds = model.predict_on_dataset(
            dataset=dataset,
            device=device,
            num_workers=num_workers,
        )

        # Record the sequences evaluated in this iteration
        curr_output = pd.DataFrame(
            {
                "iter": str(i),
                "seq": dataset.seqs,
            }
        )
        per_iter.append(curr_output)

        # Calculate losses
        loss = loss_func(preds).cpu().detach().numpy()

        # Select best sequence from current iteration
        best_seq = curr_output["seq"].iloc[loss.argmin()]

        # Mutants are only needed if another iteration follows.
        if i < max_iter:
            # NOTE: relies on ISM accepting `positions` and `drop_ref` keywords.
            start_seqs = ISM(best_seq, positions=positions, drop_ref=True)

    if not per_iter:
        return pd.DataFrame(columns=["iter", "seq"])
    # Concatenate once: avoids quadratic copying and concat-with-empty-frame.
    return pd.concat(per_iter).reset_index(drop=True)


# Ledidi
def ledidi(
    start_seq,
    model,
    to_max=None,
    to_min=None,
    max_iter=20000,
    device=0,
    num_workers=1,
    **kwargs,
):
    """Optimise sequence(s) with Ledidi against a ``DesignLoss`` objective.

    Args:
        start_seq: A sequence string, or a collection of equal-length sequence
            strings, to start from.
        model: Model handed to ``Ledidi``; it is moved to ``device`` for the
            optimisation and back to the CPU afterwards.
        to_max: Output column indices to maximise. ``None`` means ``[0]``.
        to_min: Output column indices to minimise. ``None`` or empty disables
            the minimising term.
        max_iter: Forwarded to ``Ledidi``.
        device: Anything accepted by ``torch.device``.
        num_workers: Unused; kept for signature compatibility.
        **kwargs: Extra keyword arguments forwarded to ``Ledidi``.

    Returns:
        List of designed sequence strings, one per row of the tensor returned
        by ``Ledidi.fit_transform``, decoded by taking the arg-max over dim 1.
    """
    # Fresh defaults per call instead of shared mutable default arguments.
    if to_max is None:
        to_max = [0]
    if to_min is not None and len(to_min) == 0:
        to_min = None

    # Create loss function
    # NOTE: relies on DesignLoss accepting a `reduction` keyword.
    loss_func = DesignLoss(to_max, to_min, reduction="mean")

    # str_to_one_hot returns a 2-D tensor for a bare string, which would make
    # swapaxes(1, 2) fail; wrap it so the result always has a batch dim.
    seqs = [start_seq] if isinstance(start_seq, str) else start_seq

    torch_device = torch.device(device)
    X = str_to_one_hot(seqs).swapaxes(1, 2).to(torch_device)
    model = model.to(torch_device)
    designer = Ledidi(
        model,
        X[0].shape,
        output_loss=loss_func,
        max_iter=max_iter,
        target=None,
        **kwargs,
    ).to(torch_device)

    X_hat = designer.fit_transform(X, None).cpu()
    model = model.cpu()

    # Decode: the arg-max over dim 1 gives one base index per position.
    _, indices = X_hat.max(dim=1)
    return [
        "".join(INDEX_TO_BASE_HASH[base_idx] for base_idx in row)
        for row in indices.tolist()
    ]


def evolve_specific(start_seqs, cell_type_idx, model, max_iter=10, device=0):
    """Evolve each start sequence to be specific to one task.

    Runs ``grelu.design.evolve.evolve`` separately on every start sequence,
    maximising task ``cell_type_idx`` and minimising all other tasks of
    ``model`` (tasks are enumerated as ``range(model.n_tasks)``).

    Args:
        start_seqs: Iterable of start sequences.
        cell_type_idx: Index of the task to maximise.
        model: Model with an ``n_tasks`` attribute.
        max_iter: Forwarded to the evolve call.
        device: Passed to the evolve call as ``devices=[device]``.

    Returns:
        DataFrame with columns ``seq``, ``iter`` and ``start_seq`` (the
        position of the originating sequence in ``start_seqs``). Row indices
        of the per-sequence results are kept as returned. Empty DataFrame when
        ``start_seqs`` is empty.
    """
    # NOTE(review): `grelu` is not imported in the visible part of this file;
    # confirm `grelu.design.evolve` is brought into scope elsewhere.
    off_target_idxs = [t for t in range(model.n_tasks) if t != cell_type_idx]

    frames = []
    for i, seq in enumerate(start_seqs):
        curr_df = grelu.design.evolve.evolve(
            [seq],
            model,
            loss_types=['max', 'min'],
            tasks=[cell_type_idx, off_target_idxs],
            max_iter=max_iter,
            devices=[device],
        )
        # assign() returns a new frame, so the column is never written onto a
        # slice of the result.
        frames.append(curr_df[['seq', 'iter']].assign(start_seq=i))

    # Concatenate once; pd.concat raises on an empty list, so guard it.
    return pd.concat(frames) if frames else pd.DataFrame()


def evolve_high(start_seqs, model, max_iter=10, device=0):
    """Evolve each start sequence towards a higher model output.

    Runs ``grelu.design.evolve.evolve`` with ``loss_types=['max']`` separately
    on every start sequence.

    Args:
        start_seqs: Iterable of start sequences.
        model: Model passed to the evolve call.
        max_iter: Forwarded to the evolve call.
        device: Passed to the evolve call as ``devices=[device]``.

    Returns:
        DataFrame with columns ``seq``, ``iter`` and ``start_seq`` (the
        position of the originating sequence in ``start_seqs``). Row indices
        of the per-sequence results are kept as returned. Empty DataFrame when
        ``start_seqs`` is empty.
    """
    # NOTE(review): `grelu` is not imported in the visible part of this file;
    # confirm `grelu.design.evolve` is brought into scope elsewhere.
    frames = []
    for i, start_seq in enumerate(start_seqs):
        curr_df = grelu.design.evolve.evolve(
            [start_seq],
            model,
            loss_types=['max'],
            max_iter=max_iter,
            devices=[device],
        )
        # assign() returns a new frame, so the column is never written onto a
        # slice of the result.
        frames.append(curr_df[['seq', 'iter']].assign(start_seq=i))

    # Concatenate once; pd.concat raises on an empty list, so guard it.
    return pd.concat(frames) if frames else pd.DataFrame()


def ledidi_specific(start_seqs, model, cell_type_idx, max_iter=1000, device=0, l=20, lr=3e-3):
    """Run Ledidi on each start sequence to make it specific to one task.

    Calls ``grelu.design.evolve.ledidi`` separately for every start sequence,
    maximising task ``cell_type_idx`` and minimising all other tasks of
    ``model`` (tasks are enumerated as ``range(model.n_tasks)``).

    Args:
        start_seqs: Iterable of start sequences.
        model: Model with an ``n_tasks`` attribute.
        cell_type_idx: Index of the task to maximise.
        max_iter: Forwarded to the ledidi call.
        device: Forwarded to the ledidi call.
        l: Forwarded to the ledidi call.
        lr: Forwarded to the ledidi call.

    Returns:
        DataFrame with columns ``seq`` (the designed sequences) and
        ``start_seq`` (the position of the originating sequence in
        ``start_seqs``). Empty DataFrame when ``start_seqs`` is empty.
    """
    # NOTE(review): `grelu` is not imported in the visible part of this file;
    # confirm `grelu.design.evolve` is brought into scope elsewhere.
    off_target_idxs = [t for t in range(model.n_tasks) if t != cell_type_idx]

    frames = []
    for i, start_seq in enumerate(start_seqs):
        designed = grelu.design.evolve.ledidi(
            start_seq,
            model,
            loss_types=['max', 'min'],
            tasks=[cell_type_idx, off_target_idxs],
            max_iter=max_iter,
            device=device,
            l=l,
            lr=lr,
        )
        frames.append(pd.DataFrame({'seq': designed}).assign(start_seq=i))

    # Concatenate once rather than growing from an empty frame, which relies
    # on deprecated dtype handling; pd.concat raises on an empty list.
    return pd.concat(frames) if frames else pd.DataFrame()


def ledidi_high(start_seqs, model, max_iter=1000, device=0, l=20, lr=3e-3):
    """Run Ledidi on each start sequence to raise the model output.

    Calls ``grelu.design.evolve.ledidi`` with ``loss_types=['max']``
    separately for every start sequence.

    Args:
        start_seqs: Iterable of start sequences.
        model: Model passed to the ledidi call.
        max_iter: Forwarded to the ledidi call.
        device: Forwarded to the ledidi call.
        l: Forwarded to the ledidi call.
        lr: Forwarded to the ledidi call.

    Returns:
        DataFrame with columns ``seq`` (the designed sequences) and
        ``start_seq`` (the position of the originating sequence in
        ``start_seqs``). Empty DataFrame when ``start_seqs`` is empty.
    """
    # NOTE(review): `grelu` is not imported in the visible part of this file;
    # confirm `grelu.design.evolve` is brought into scope elsewhere.
    frames = []
    for i, start_seq in enumerate(start_seqs):
        designed = grelu.design.evolve.ledidi(
            start_seq,
            model,
            loss_types=['max'],
            max_iter=max_iter,
            device=device,
            l=l,
            lr=lr,
        )
        frames.append(pd.DataFrame({'seq': designed}).assign(start_seq=i))

    # Concatenate once rather than growing from an empty frame, which relies
    # on deprecated dtype handling; pd.concat raises on an empty list.
    return pd.concat(frames) if frames else pd.DataFrame()


def match(new, gen, pred_cols):
    """Greedily pair each row of ``gen`` with its closest row in ``new``.

    Rows of ``gen`` are processed in order. For each one, the candidate in
    ``new`` with the smallest summed absolute difference over ``pred_cols`` is
    selected, and then every candidate sharing that row's ``start_seq`` value
    is removed from the pool so a start sequence is used at most once.

    Args:
        new: DataFrame of candidates. Must contain the columns ``start_seq``,
            ``Sequence``, ``Group`` and every column in ``pred_cols``. It is
            not modified.
        gen: DataFrame of reference rows; must contain every column in
            ``pred_cols``.
        pred_cols: Names of the columns to compare.

    Returns:
        DataFrame with columns ``['Sequence', 'Group'] + pred_cols`` holding
        the matched rows of ``new`` in ``gen`` order, with a fresh index.
        Empty (with those columns) when ``gen`` has no rows.

    Raises:
        ValueError: If the candidate pool is exhausted before every row of
            ``gen`` has been matched.
    """
    pred_cols = list(pred_cols)
    out_cols = ['Sequence', 'Group'] + pred_cols

    pool = new
    picked = []

    for i in range(len(gen)):
        if len(pool) == 0:
            raise ValueError(
                f"ran out of candidate sequences after matching {i} of "
                f"{len(gen)} rows"
            )

        # Summed absolute difference to the i-th reference row
        dist = np.sum(
            [np.abs(pool[col].to_numpy() - gen[col].iloc[i]) for col in pred_cols],
            axis=0,
        )
        match_idx = np.argmin(dist)

        # Store
        picked.append(pool.iloc[[match_idx]])

        # Drop every candidate derived from the same start sequence
        used_start = pool['start_seq'].iloc[match_idx]
        pool = pool[pool['start_seq'] != used_start]

    if not picked:
        return new.iloc[0:0][out_cols].reset_index(drop=True)
    # Concatenate once instead of growing a frame inside the loop.
    return pd.concat(picked).reset_index(drop=True)[out_cols]